package com.wys.spring.db;

import cn.hutool.db.ds.simple.AbstractDataSource;
import cn.hutool.setting.yaml.YamlUtil;
import cn.zhxu.bs.SqlExecutor;
import cn.zhxu.bs.implement.DefaultSqlExecutor;
import com.alibaba.cloud.nacos.NacosConfigManager;
import com.alibaba.cloud.nacos.NacosConfigProperties;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.spring.boot.autoconfigure.DruidDataSourceAutoConfigure;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.TypeReference;
import com.alibaba.nacos.api.config.ConfigService;
import com.alibaba.nacos.api.config.listener.AbstractSharedListener;
import com.alibaba.nacos.api.exception.NacosException;
import com.alibaba.nacos.common.executor.NameThreadFactory;
import com.baomidou.mybatisplus.autoconfigure.MybatisPlusAutoConfiguration;
import com.baomidou.mybatisplus.autoconfigure.MybatisPlusProperties;
import com.baomidou.mybatisplus.autoconfigure.SpringBootVFS;
import com.baomidou.mybatisplus.core.MybatisConfiguration;
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.baomidou.mybatisplus.core.incrementer.IKeyGenerator;
import com.baomidou.mybatisplus.core.incrementer.IdentifierGenerator;
import com.baomidou.mybatisplus.core.injector.ISqlInjector;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.github.yulichang.injector.MPJSqlInjector;
import com.wys.api.exception.BizException;
import com.wys.spring.mybatisplus.generator.MybatisPlusGeneratorProperties;
import com.wys.utils.JsonUtils;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.transaction.TransactionFactory;
import org.apache.shardingsphere.shardingjdbc.jdbc.adapter.AbstractDataSourceAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.AutoConfigureBefore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.bind.BindContext;
import org.springframework.boot.context.properties.bind.BindHandler;
import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.context.properties.source.ConfigurationPropertyName;
import org.springframework.cloud.endpoint.event.RefreshEvent;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationListener;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.env.Environment;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.jdbc.datasource.lookup.MapDataSourceLookup;
import org.springframework.transaction.TransactionManager;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

import javax.sql.DataSource;
import java.io.StringReader;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.function.Consumer;

@Configuration
@AutoConfigureBefore(value = {MybatisPlusAutoConfiguration.class, DataSourceAutoConfiguration.class, DruidDataSourceAutoConfigure.class})
@EnableConfigurationProperties(value = {MultipleDataSourceConfig.class, MybatisPlusProperties.class, MybatisPlusGeneratorProperties.class})
public class SpringMultipleDataSourceConfiguration implements BeanDefinitionRegistryPostProcessor, EnvironmentAware, ApplicationContextAware, ApplicationListener<RefreshEvent> {

    /** Configuration prefix under which the named data sources are declared ({@code multiple.data-source.<name>.*}). */
    private static final String PreFix = "multiple.data-source";

    /** Name of the mandatory primary (default) data source. */
    private static final String MASTER = "master";

    private static final Logger logger = LoggerFactory.getLogger(SpringMultipleDataSourceConfiguration.class);

    /**
     * Static cache of all Druid data sources keyed by configured name.
     * Populated in {@link #buildDruidDataSource(BeanDefinitionRegistry)} and replaced wholesale
     * on a Nacos-driven refresh in {@link #onApplicationEvent(RefreshEvent)}.
     */
    private static final Map<String, DruidDataSource> multipleDataSource = new ConcurrentHashMap<>();

    private Environment environment;

    private ApplicationContext applicationContext;

    private ConfigurableListableBeanFactory configurableListableBeanFactory;

    private BeanDefinitionRegistry beanDefinitionRegistry;

    private MultipleDataSourceConfig multipleDataSourceConfig;


    /**
     * Exposes the routing data source that dispatches to the named Druid pools.
     *
     * @param multipleDataSourceConfig  bound {@code multiple.data-source} properties
     * @param dataSourceObjectProvider  all other {@link DataSource} beans in the context
     *                                  (ShardingSphere / hutool sources are folded in under fixed keys)
     */
    @Bean(name = "dynamicDataSource")
    public DynamicDataSource dynamicDataSource(MultipleDataSourceConfig multipleDataSourceConfig, ObjectProvider<DataSource[]> dataSourceObjectProvider) {
        this.multipleDataSourceConfig = multipleDataSourceConfig;
        return buildDynamicDataSource(multipleDataSourceConfig, dataSourceObjectProvider);
    }


    /** Advisor that switches the routing key around annotated methods (see DataSourceInterceptor). */
    @Bean
    public DataSourcePointcutAdvisor dataSourcePointcutAdvisor() {
        return new DataSourcePointcutAdvisor(new DataSourceInterceptor());
    }

    /** JdbcTemplate bound to the routing data source, so SQL follows the current routing key. */
    @Bean(name = "dynamicJdbcTemplate")
    public JdbcTemplate jdbcTemplate(DynamicDataSource abstractRoutingDataSource) {
        return new JdbcTemplate(abstractRoutingDataSource);
    }

    /** Transaction manager bound to the routing data source. */
    @Bean
    public TransactionManager transactionManager(@Qualifier("dynamicDataSource") DynamicDataSource abstractRoutingDataSource) {
        return new DataSourceTransactionManager(abstractRoutingDataSource);
    }


    /**
     * Registers one {@link DruidDataSource} bean definition per configured entry,
     * before regular bean creation starts.
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry beanDefinitionRegistry) throws BeansException {
        this.beanDefinitionRegistry = beanDefinitionRegistry;
        buildDruidDataSource(beanDefinitionRegistry);
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        this.configurableListableBeanFactory = configurableListableBeanFactory;
    }

    /**
     * Primary MyBatis-Plus session factory wired to the routing data source.
     * Replaces the one MybatisPlusAutoConfiguration would otherwise create.
     */
    @Bean
    @Primary
    public MybatisSqlSessionFactoryBean mybatisSqlSessionFactoryBean(@Qualifier("dynamicDataSource") DynamicDataSource abstractRoutingDataSource, MybatisPlusProperties properties, ObjectProvider<Interceptor[]> interceptorsProvider) throws Exception {
        return buildSqlSessionFactoryBean(null, abstractRoutingDataSource, properties, interceptorsProvider);
    }

    @Bean(name = "sqlSessionFactory")
    @ConditionalOnBean
    public SqlSessionFactory sqlSessionFactory(MybatisSqlSessionFactoryBean mybatisSqlSessionFactoryBean) throws Exception {
        return mybatisSqlSessionFactoryBean.getObject();
    }

    /**
     * Builds the routing data source from the static cache plus any extra
     * {@link DataSource} beans present in the context.
     *
     * @throws BizException if no data source has been registered (the master source is mandatory)
     */
    public DynamicDataSource buildDynamicDataSource(MultipleDataSourceConfig multipleDataSourceConfig, ObjectProvider<DataSource[]> dataSourceObjectProvider) {
        DynamicDataSource dynamicDataSource = new DynamicDataSource(multipleDataSourceConfig);
        if (multipleDataSource.isEmpty()) {
            throw new BizException("默认数据源不能为空");
        }
        Map<Object, Object> objectMap = new HashMap<>();
        // The provider scan is independent of the configured keys: do it once, not per key.
        DataSource[] extraDataSources = dataSourceObjectProvider.getIfAvailable();
        if (!ObjectUtils.isEmpty(extraDataSources)) {
            for (DataSource dataSource : extraDataSources) {
                if (dataSource instanceof DynamicDataSource) {
                    // Never route to ourselves.
                    continue;
                }
                if (dataSource instanceof AbstractDataSourceAdapter) {
                    objectMap.put("shardingJDBC", dataSource);
                }
                if (dataSource instanceof AbstractDataSource) {
                    objectMap.put("hutool", dataSource);
                }
            }
        }
        multipleDataSource.forEach((name, dataSource) -> {
            if (MASTER.equalsIgnoreCase(name)) {
                dynamicDataSource.setDefaultTargetDataSource(dataSource);
            }
            objectMap.put(name, dataSource);
        });
        dynamicDataSource.setTargetDataSources(objectMap);
        return dynamicDataSource;
    }

    /**
     * Bean-searcher executor: default source is the routing data source, and every named
     * Druid pool is also registered so searches can target a specific source by name.
     */
    @Bean
    public SqlExecutor sqlExecutor(DynamicDataSource dynamicDataSource) {
        DefaultSqlExecutor defaultSqlExecutor = new DefaultSqlExecutor();
        defaultSqlExecutor.setDataSource(dynamicDataSource);
        // multipleDataSource is a never-null static map; a plain emptiness check suffices.
        multipleDataSource.forEach(defaultSqlExecutor::setDataSource);
        return defaultSqlExecutor;
    }

    /**
     * Configures a {@link MybatisSqlSessionFactoryBean} from {@link MybatisPlusProperties},
     * mirroring what MybatisPlusAutoConfiguration does but against the routing data source.
     * Note: the global config's SqlInjector is always overridden with {@link MPJSqlInjector}
     * at the end, even if an {@link ISqlInjector} bean was applied just before.
     */
    private MybatisSqlSessionFactoryBean buildSqlSessionFactoryBean(MybatisSqlSessionFactoryBean factory, DataSource source, MybatisPlusProperties properties, ObjectProvider<Interceptor[]> interceptorsProvider) throws Exception {
        if (factory == null) {
            factory = new MybatisSqlSessionFactoryBean();
        }
        factory.setDataSource(source);
        factory.setConfiguration(new MybatisConfiguration());
        factory.setFailFast(true);
        factory.setVfs(SpringBootVFS.class);
        if (StringUtils.hasText(properties.getConfigLocation())) {
            factory.setConfigLocation(new DefaultResourceLoader().getResource(properties.getConfigLocation()));
        }
        if (properties.getConfigurationProperties() != null) {
            factory.setConfigurationProperties(properties.getConfigurationProperties());
        }
        Interceptor[] plugins = interceptorsProvider.getIfAvailable();
        if (!ObjectUtils.isEmpty(plugins)) {
            logger.warn("开始加载MybatisPlus插件:{}", (Object) plugins);
            factory.setPlugins(plugins);
        }
        if (StringUtils.hasLength(properties.getTypeAliasesPackage())) {
            factory.setTypeAliasesPackage(properties.getTypeAliasesPackage());
        }
        if (properties.getTypeAliasesSuperType() != null) {
            factory.setTypeAliasesSuperType(properties.getTypeAliasesSuperType());
        }
        if (StringUtils.hasLength(properties.getTypeHandlersPackage())) {
            factory.setTypeHandlersPackage(properties.getTypeHandlersPackage());
        }
        if (!ObjectUtils.isEmpty(properties.resolveMapperLocations())) {
            factory.setMapperLocations(properties.resolveMapperLocations());
        }
        this.getBeanThen(TransactionFactory.class, factory::setTransactionFactory);
        GlobalConfig globalConfig = properties.getGlobalConfig();
        this.getBeanThen(MetaObjectHandler.class, globalConfig::setMetaObjectHandler);
        this.getBeansThen(IKeyGenerator.class, (i) -> globalConfig.getDbConfig().setKeyGenerators(i));
        this.getBeanThen(ISqlInjector.class, globalConfig::setSqlInjector);
        this.getBeanThen(IdentifierGenerator.class, globalConfig::setIdentifierGenerator);
        // mybatis-plus-join requires its own injector; this intentionally wins over any ISqlInjector bean.
        globalConfig.setSqlInjector(new MPJSqlInjector());
        factory.setGlobalConfig(globalConfig);
        return factory;
    }

    /** Passes the unique bean of {@code clazz} to {@code consumer} if (and only if) one exists. */
    private <T> void getBeanThen(Class<T> clazz, Consumer<T> consumer) {
        if (this.applicationContext.getBeanNamesForType(clazz, false, false).length > 0) {
            consumer.accept(this.applicationContext.getBean(clazz));
        }
    }

    /** Passes all beans of {@code clazz} to {@code consumer} if at least one exists. */
    private <T> void getBeansThen(Class<T> clazz, Consumer<List<T>> consumer) {
        if (this.applicationContext.getBeanNamesForType(clazz, false, false).length > 0) {
            consumer.accept(new ArrayList<>(this.applicationContext.getBeansOfType(clazz).values()));
        }
    }

    /**
     * Binds {@code multiple.data-source.*} from the environment, registers a
     * {@link DruidDataSource} bean definition per entry (the master one as primary),
     * copies the bound settings onto the created bean and caches it.
     *
     * @throws BizException if the prefix binds to an empty map
     */
    public void buildDruidDataSource(BeanDefinitionRegistry beanDefinitionRegistry) {
        Map<String, Object> map = Binder.get(environment, new BindHandler() {
            @Override
            public <T> Bindable<T> onStart(ConfigurationPropertyName name, Bindable<T> target, BindContext context) {
                logger.warn("开始绑定属性:{}", name.toString());
                return BindHandler.super.onStart(name, target, context);
            }

            @Override
            public Object onSuccess(ConfigurationPropertyName name, Bindable<?> target, BindContext context, Object result) {
                logger.warn("属性绑定成功:{}", name.toString());
                return result;
            }
        }).bind(PreFix, Bindable.mapOf(String.class, Object.class)).get();
        if (map.isEmpty()) {
            throw new BizException("引用了SpringMultipleDataSourceConfiguration就一定要配置MultipleDataSourceConfig");
        }
        map.keySet().forEach(name -> {
            // Re-bind the individual entry directly onto a DruidDataSource template.
            DruidDataSource boundSettings = Binder.get(environment).bind(PreFix + "." + name, DruidDataSource.class).get();
            BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(DruidDataSource.class);
            if (MASTER.equalsIgnoreCase(name)) {
                beanDefinitionBuilder.setPrimary(true);
            }
            beanDefinitionRegistry.registerBeanDefinition(name, beanDefinitionBuilder.getBeanDefinition());
            // Getting the bean here forces eager creation; the bound settings are then copied onto it.
            DruidDataSource druidDataSource = applicationContext.getBean(name, DruidDataSource.class);
            BeanUtils.copyProperties(boundSettings, druidDataSource);
            multipleDataSource.put(name, druidDataSource);
        });
    }

    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    /**
     * 监听springCloud发布的配置刷新事件
     * <p>
     * On a Nacos-originated refresh (and only when {@code enableRefresh} is on):
     * pulls the first extension config, parses the {@code multiple.data-source} section,
     * then removes dropped sources, refreshes changed ones and registers new ones,
     * finally rebuilding the routing data source's resolved map.
     *
     * @param event the event to respond to
     */
    @Override
    public void onApplicationEvent(RefreshEvent event) {
        //对springCloud发布的刷新事件进行处理
        if (event.getSource() instanceof AbstractSharedListener && multipleDataSourceConfig.isEnableRefresh()) {
            NacosConfigManager nacosConfigManager = applicationContext.getBean(NacosConfigManager.class);
            ConfigService configService = nacosConfigManager.getConfigService();
            List<NacosConfigProperties.Config> configs = nacosConfigManager.getNacosConfigProperties().getExtensionConfigs();
            try {
                //调用NACOS CLIENT API 获取配置
                String config = configService.getConfig(configs.get(0).getDataId(), configs.get(0).getGroup(), 3000L);
                logger.warn("拉取nacos配置:{}", config);
                Map<String, Object> dict = YamlUtil.load(new StringReader(config));
                Map<String, DruidDataSource> map = JSON.parseObject(JsonUtils.object2Json(dict)).getJSONObject("multiple").getJSONObject("data-source").to(new TypeReference<Map<String, DruidDataSource>>() {
                });
                // Diff against the cache BEFORE it is replaced below.
                Map<String, DruidDataSource> dataSourceMap = checkDataSourceChange(map);
                // NOTE(review): this replaces the cache of live, pooled DruidDataSource instances with
                // freshly deserialized (never-started) ones, even when nothing changed — consumers of
                // multipleDataSource (e.g. sqlExecutor) built later would see un-initialized pools. Confirm intent.
                multipleDataSource.clear();
                multipleDataSource.putAll(map);
                DynamicDataSource dynamicDataSource = applicationContext.getBean(DynamicDataSource.class);
                Map<Object, Object> objectMap = new HashMap<>(map);
                String[] beans = applicationContext.getBeanNamesForType(DruidDataSource.class);
                //如果数据源没有发生变动，就直接return
                if (multipleDataSource.size() == beans.length && dataSourceMap.isEmpty()) {
                    return;
                }
                logger.warn("缓存中的数据源:{},springBean中的数据源:{}", JsonUtils.object2Json(multipleDataSource.keySet()), JsonUtils.object2Json(beans));
                for (String bean : beans) {
                    //至少保留一个主数据源
                    if (beanDefinitionRegistry.getBeanDefinition(bean).isPrimary()) {
                        continue;
                    }
                    if (multipleDataSource.size() < beans.length && multipleDataSource.get(bean) == null) {
                        logger.warn("检测到多余的dataSource,删除BeanName:{}", bean);
                        applicationContext.getBean(bean, DruidDataSource.class).close();
                        dynamicDataSource.getResolvedDataSources().remove(bean);
                        beanDefinitionRegistry.removeBeanDefinition(bean);
                    }
                }
                if (!dataSourceMap.isEmpty()) {
                    dataSourceMap.keySet().forEach(k -> {
                        //如果DruidSource在bean容器中已经存在了，只需要刷新就好
                        if (k.equals(MASTER) || Arrays.asList(applicationContext.getBeanNamesForType(DruidDataSource.class)).contains(k)) {
                            DruidDataSource druidDataSource = applicationContext.getBean(k, DruidDataSource.class);
                            refreshDataSource(druidDataSource, dataSourceMap.get(k));
                            logger.warn("刷新数据源:{},JDBC_URL:{}", k, druidDataSource.getUrl());
                        } else {
                            //新增的数据源需要注册到spring容器中
                            beanDefinitionRegistry.registerBeanDefinition(k, BeanDefinitionBuilder.genericBeanDefinition(DruidDataSource.class).getBeanDefinition());
                            DruidDataSource druidDataSource = applicationContext.getBean(k, DruidDataSource.class);
                            objectMap.put(k, druidDataSource);
                            refreshDataSource(druidDataSource, dataSourceMap.get(k));
                            druidDataSource.resetStat();
                            logger.warn("新增数据源:{},JDBC_URL:{}", k, druidDataSource.getUrl());
                        }
                    });
                    dynamicDataSource.setTargetDataSources(objectMap);
                    dynamicDataSource.setDataSourceLookup(new MapDataSourceLookup(new HashMap<>(dataSourceMap)));
                    //刷新动态数据源中的resolvedDataSources
                    dynamicDataSource.afterPropertiesSet();
                }
            } catch (NacosException e) {
                throw new RuntimeException(e);
            } catch (Exception e) {
                logger.error("异常:", e);
            }
        }
    }


    /**
     * Applies conservative pool defaults to a freshly refreshed source, only when the
     * current value still looks like a Druid default (so explicit configuration wins).
     */
    private void initDruidDataSource(DruidDataSource dataSource) {
        try {
            //设置默认参数
            if (dataSource.getMaxActive() == 8 || dataSource.getMaxActive() == 5) {
                dataSource.setMaxActive(100);
            }
            if (dataSource.getInitialSize() == 0 || dataSource.getInitialSize() == 1) {
                dataSource.setInitialSize(10);
            }
            if (!dataSource.isPoolPreparedStatements()) {
                dataSource.setMaxPoolPreparedStatementPerConnectionSize(5);
            }
            //设置获取连接的最大等待时间为10s
            if (dataSource.getMaxWait() < 0 || dataSource.getMaxWait() > 5000L) {
                dataSource.setMaxWait(5000L);
            }
            if (dataSource.getValidationQuery() == null) {
                dataSource.setValidationQuery("SELECT 'x'");
            }
            if (dataSource.getValidationQueryTimeout() < 0) {
                dataSource.setValidationQueryTimeout(0);
            }
        } catch (Exception e) {
            // Keep the stack trace: the message alone is not enough to diagnose a pool-init failure.
            logger.error("初始化druid数据源发生错误", e);
        }
    }


    /**
     * 刷新数据源
     * <p>
     * Restarts the live pool and copies the connection coordinates from the newly
     * bound configuration object onto it, then re-applies the pool defaults.
     *
     * @param druidDataSource    旧的数据源 (the live pool to mutate)
     * @param newDruidDataSource 新的数据源 (carrier of the new url/credentials, never started)
     * @throws RuntimeException wrapping the {@link SQLException} if the restart fails
     */
    public void refreshDataSource(DruidDataSource druidDataSource, DruidDataSource newDruidDataSource) {
        // NOTE(review): a 10000-core-thread scheduler is created on EVERY refresh and the previous
        // one is never shut down — looks like a resource leak / mis-sized pool; confirm before changing.
        ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = new ScheduledThreadPoolExecutor(10000, new NameThreadFactory("nacos-config-refresh druidDataSource" + newDruidDataSource.getVersion()));
        try {
            druidDataSource.restart();
        } catch (SQLException e) {
            logger.error("数据源刷新失败", e);
            throw new RuntimeException(e);
        }
        druidDataSource.setCreateScheduler(scheduledThreadPoolExecutor);
        druidDataSource.setPassword(newDruidDataSource.getPassword());
        druidDataSource.setUrl(newDruidDataSource.getUrl());
        druidDataSource.setUsername(newDruidDataSource.getUsername());
        druidDataSource.setResetStatEnable(true);
        druidDataSource.setDriverClassName(newDruidDataSource.getDriverClassName());
        initDruidDataSource(druidDataSource);
    }

    /**
     * 检查数据源是否和初始化加载时的数据源一致，如果一致就不做刷新
     *
     * @param newDataSource the freshly parsed sources keyed by name
     * @return the subset of {@code newDataSource} that is new or whose connection
     *         coordinates (url/password/username/driver) differ from the cached entry
     */
    public Map<String, DruidDataSource> checkDataSourceChange(Map<String, DruidDataSource> newDataSource) {
        Map<String, DruidDataSource> dataSourceMap = new ConcurrentHashMap<>(16);
        if (!multipleDataSource.isEmpty() && !newDataSource.isEmpty()) {
            newDataSource.forEach((name, fresh) -> {
                DruidDataSource cached = multipleDataSource.get(name);
                if (cached == null) {
                    // Entirely new source.
                    dataSourceMap.put(name, fresh);
                    return;
                }
                //如果检测到新数据源与静态缓存中的旧数据源不一致就对不一致缓存就行刷新
                // Objects.equals avoids the NPE the previous field-by-field .equals chain hit on null values.
                boolean changed = !Objects.equals(cached.getUrl(), fresh.getUrl())
                        || !Objects.equals(cached.getPassword(), fresh.getPassword())
                        || !Objects.equals(cached.getUsername(), fresh.getUsername())
                        || !Objects.equals(cached.getDriverClassName(), fresh.getDriverClassName());
                if (changed) {
                    dataSourceMap.put(name, fresh);
                }
            });
        }
        return dataSourceMap;
    }


}
